#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 10 15:19:55 2020

@author: matthewroberts
"""

import boto3
from boto3.session import Session
import numpy as np
import numpy.ma as ma
import sys, os
from netCDF4 import Dataset
import pandas as pd

# import imageio
# import cv2
import glob

import datetime as dt
from scipy.ndimage.filters import minimum_filter, maximum_filter
from scipy import spatial
from scipy.signal import savgol_filter
from scipy.interpolate import griddata

import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cartopy.io.shapereader as shpreader
from cartopy.mpl.ticker import (LongitudeFormatter, LatitudeFormatter,
                                LatitudeLocator, LongitudeLocator)

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.dates as mdates

from shapely.geometry import Point
from shapely.geometry.polygon import Polygon

import warnings
warnings.filterwarnings('ignore')

# Wall-clock start of the run, echoed to the console.
start_time = dt.datetime.now()
print('\nProgram Started: {0}'.format(start_time))
print('===================')

##########################################
# Inputs
##########################################
# Satellite name; appended to 'noaa-' to form the S3 bucket name.
satname = 'goes17' #goes16 or goes17
# Fire product
product = 'ABI-L2-FDCC'
# Scan-mode tag that follows the product name in the object-key prefix.
mode = 'M6' #usually M6 but M3 for older data

# Start/end time (YYYY, M, d, H)
# Exactly one case block below should be active; the others stay commented out.

# #bear
# keyword = 'bear_fire' #bear_fire
# sDATE = dt.datetime(2020, 9, 8, 19)
# eDATE = dt.datetime(2020, 9, 9, 4)
# # Region to plot (llon,rlon,llat,ulat)
# extent = [-121.3, -120.76, 39.6, 39.95] # bear fire

#caldor
# keyword labels the case: it names the plot directory and the output figure.
keyword = 'caldor_fire' #caldor_fire
sDATE = dt.datetime(2021, 8, 17, 15)
eDATE = dt.datetime(2021, 8, 18, 0)
# Region to plot (llon,rlon,llat,ulat)
extent = [-120.7, -120.0, 38.45, 38.84] # caldor fire

# NOTE(review): DoPlot and CreateVideo are not read anywhere in the visible
# script; confirm whether they are still needed.
# Make plots or not
DoPlot = False
# Create video of all plots when finished
CreateVideo = False

# Directories to use. Will be created if they do not exist
mainpath = '/Users/matthewroberts/Documents/Projects/LEAPHI'
filepath = mainpath+'/data/goes_data'
savepath = mainpath+'/plots/goes_plots/'+keyword

# AWS credentials
# NOTE(review): placeholders only. Avoid committing real keys to source
# control; prefer environment variables or a shared credentials file.
ACCESS_KEY='xxxxx'
SECRET_KEY='xxxxx'

##########################################
# Verify local directories
##########################################
print('\nSetting up...')
print('Mainpath: '+mainpath)
print('Filepath: '+filepath)
print('Savepath: '+savepath)
# Create every working directory, including any missing parents.
# exist_ok=True makes the call a no-op for directories that already exist and
# removes the check-then-create race of a separate os.path.exists() test.
for _workdir in (mainpath, filepath, savepath):
    os.makedirs(_workdir, exist_ok=True)

##########################################
# Set up functions
##########################################
def latlon_convert(proj_info,lat_rad_1d,lon_rad_1d):
    """
    Convert 1-D satellite scan-angle vectors (radians) to 2-D lat/lon grids.

    proj_info must expose longitude_of_projection_origin (degrees),
    perspective_point_height, semi_major_axis and semi_minor_axis.
    lat_rad_1d and lon_rad_1d are expanded with np.meshgrid, so both outputs
    have shape (len(lon_rad_1d), len(lat_rad_1d)).

    Returns (lat, lon) in degrees. Where the square-root argument is negative
    the result is not a valid number (NaN for plain arrays).
    """
    # Ellipsoid radii and satellite distance from the Earth's centre
    r_eq = proj_info.semi_major_axis
    r_pol = proj_info.semi_minor_axis
    H = proj_info.perspective_point_height+proj_info.semi_major_axis
    # Sub-satellite longitude in radians
    lambda_0 = (proj_info.longitude_of_projection_origin*np.pi)/180.0
    # Squared equatorial/polar radius ratio, used twice below
    axis_ratio = (r_eq*r_eq)/(r_pol*r_pol)

    # 2-D grids of the two scan angles
    lat_rad,lon_rad = np.meshgrid(lat_rad_1d,lon_rad_1d)
    sin_lat, cos_lat = np.sin(lat_rad), np.cos(lat_rad)
    sin_lon, cos_lon = np.sin(lon_rad), np.cos(lon_rad)

    # Quadratic a*r^2 + b*r + c = 0 for the distance from the satellite to
    # the point where the line of sight meets the ellipsoid
    a_var = np.power(sin_lat,2.0) + (np.power(cos_lat,2.0)*(np.power(cos_lon,2.0)+(axis_ratio*np.power(sin_lon,2.0))))
    b_var = -2.0*H*cos_lat*cos_lon
    c_var = (H**2.0)-(r_eq**2.0)

    # Smaller root of the quadratic
    discriminant = (b_var**2)-(4.0*a_var*c_var)
    r_s = (-1.0*b_var - np.sqrt(discriminant))/(2.0*a_var)

    # Cartesian components of the intersection point
    s_x = r_s*cos_lat*cos_lon
    s_y = - r_s*sin_lat
    s_z = r_s*cos_lat*sin_lon

    # Geodetic latitude/longitude in degrees
    to_deg = 180.0/np.pi
    horiz = np.sqrt(((H-s_x)*(H-s_x))+(s_y*s_y))
    lat = to_deg*(np.arctan(axis_ratio*((s_z/horiz))))
    lon = (lambda_0 - np.arctan(s_y/(H-s_x)))*to_deg

    return lat, lon

def plumeProp(latVar,lonVar,coords,sens):
    """
    Find the index bounding box of all grid points inside a lat/lon window.

    latVar/lonVar = 2-D latitude/longitude arrays of the same shape
                    (plain or masked)
    coords = [lon,lat]  #Centroid of plume area
    sens = [lon,lat]    #Distance from centroid to draw box (deg)

    A point counts as inside when both coordinates are strictly within
    centroid +/- sens. Masked points never match, and every row whose first
    latitude element is masked is skipped entirely.

    Returns (row_min, row_max, col_min, col_max): the smallest and largest
    first-axis and second-axis indexes of the matching points (inclusive).

    Raises ValueError when no grid point falls inside the window.
    """
    lat = np.ma.asarray(latVar)
    lon = np.ma.asarray(lonVar)

    lon_lo, lon_hi = coords[0]-sens[0], coords[0]+sens[0]
    lat_lo, lat_hi = coords[1]-sens[1], coords[1]+sens[1]

    # One vectorized test for the whole grid instead of a Python loop over
    # every element; masked comparisons are treated as "not inside".
    inside = (lat > lat_lo) & (lat < lat_hi) & (lon > lon_lo) & (lon < lon_hi)
    inside = np.array(np.ma.filled(inside, False), dtype=bool)

    # Rows whose first element carries no data are ignored as a whole
    # (fringes of the satellite viewing area).
    inside[np.ma.getmaskarray(lat[:, 0])] = False

    rows, cols = np.nonzero(inside)
    if rows.size == 0:
        raise ValueError(
            'plumeProp: no grid points within +/-{0} deg of centroid {1}'.format(
                list(sens), list(coords)))

    # Lowest/highest indexes give the box corners
    return rows.min(), rows.max(), cols.min(), cols.max()

# Creates shaded relief background on maps
from cartopy.io.img_tiles import GoogleTiles
class ShadedReliefESRI(GoogleTiles):
    """Tile source that requests World_Shaded_Relief tiles from ArcGIS Online."""

    def _image_url(self, tile):
        """Build the request URL for one (x, y, z) tile."""
        tile_x, tile_y, tile_z = tile
        template = ('https://server.arcgisonline.com/ArcGIS/rest/services/'
                    'World_Shaded_Relief/MapServer/tile/{z}/{y}/{x}.jpg')
        return template.format(z=tile_z, y=tile_y, x=tile_x)

##########################################
# Set up dates
##########################################
print('\n{0} UTC  -->  {1} UTC'.format(sDATE,eDATE))
# Hourly timestamps to download. The trailing element (eDATE itself) is
# dropped so the last hour fetched is the one before eDATE.
date_list = pd.date_range(start=sDATE, end=eDATE, freq='1H')[:-1]
# Ten evenly spaced timestamps, intended for plot tick labels
ticks_list = pd.date_range(start=sDATE, end=eDATE, periods=10)

##########################################
# AWS credentials and open session
##########################################
session = Session(
    aws_access_key_id=ACCESS_KEY,
    aws_secret_access_key=SECRET_KEY,
)
s3 = session.client('s3')

# Per-scan accumulators filled by the download loop:
# scan mid-point times, FRP grids, and cloud / saturation / fire flag grids
FRP_times1 = []
_FRP, _cloud, _sat, _fire = [], [], [], []

#%%
# Name of AWS bucket and local scratch file (the same for every iteration;
# each download overwrites the previous file)
bucket = 'noaa-'+satname
filename = filepath+'/satfile.nc'

for scan_hour in date_list:
    # Filepath within bucket to get files:
    # <product>/<year>/<day of year>/<hour>/OR_<product>-<mode>
    prefix = (product + '/' + str(scan_hour.year) + '/'
              + str(scan_hour.day_of_year).zfill(3) + '/'
              + str(scan_hour.hour).zfill(2) + '/OR_' + product + '-' + mode)
    # Create file list to prepare for download. The 'Contents' entry is
    # missing from the response when nothing matches the prefix, so fall back
    # to an empty list instead of raising KeyError.
    filelist = s3.list_objects_v2(Bucket=bucket, Prefix=prefix).get('Contents', [])
    if not filelist:
        print('\nNo files found for '+prefix+' - skipping hour')
        continue

    # Go through filelist and download each file
    for obj in filelist:
        print('\nDownloading files from AWS...', end='', flush=True)
        s3.download_file(bucket, obj['Key'], filename)

        # The context manager closes the file even if a read below fails
        with Dataset(filename,'r') as data:
            ##########################################
            # Fire Radiative Power
            ##########################################
            # Scan time: 't' is added as seconds to 2000-01-01 12:00
            midpoint = float(data.variables['t'][:])
            scan_mid = dt.datetime(2000,1,1,12) + dt.timedelta(seconds=midpoint)
            print(dt.datetime.strftime(scan_mid,'%Y%m%d_%H%M'))

            # Projection info and the file's 1-D 'x'/'y' coordinate vectors
            projInfo = data.variables['goes_imager_projection']
            latRad = data.variables['x'][:]
            lonRad = data.variables['y'][:]

            # 2-D latitude (y) and longitude (x) grids. They are overwritten
            # for every file; the grids of the last file are the ones used
            # after the loop.
            y, x = latlon_convert(projInfo,latRad,lonRad)
            # data and saturation info
            fireFRP = data.variables['Power'][:]
            fireFlags = data.variables['Mask'][:]

        cloudmask = np.zeros(fireFRP.shape)
        satmask = np.zeros(fireFRP.shape)
        firemask = np.zeros(fireFRP.shape)

        # Set cloud contaminated pixel flags (12,32) to mean FRP of good fire pixels
        fireFRP[fireFlags == 12] = np.ma.mean(fireFRP)
        fireFRP[fireFlags == 32] = np.ma.mean(fireFRP)
        # flag cloud contaminated pixels
        cloudmask[fireFlags == 12] = 99
        cloudmask[fireFlags == 32] = 99

        # Set saturated pixel flags (11,31) to max FRP of good fire pixels
        fireFRP[fireFlags == 11] = fireFRP.max()
        fireFRP[fireFlags == 31] = fireFRP.max()
        # flag saturated pixels
        satmask[fireFlags == 11] = 99
        satmask[fireFlags == 31] = 99

        # flag fire pixels (includes the cloud/saturated pixels filled above)
        firemask[fireFRP > 0] = 99

        # Append all FRP data
        _FRP.append(fireFRP)
        # times
        FRP_times1.append(scan_mid)
        # arrays of flags
        _cloud.append(cloudmask)
        _sat.append(satmask)
        _fire.append(firemask)

# create arrays with time axis (3D)
# Turn the per-scan lists into 3-D arrays with time on axis 0.
# Fmask totals include the cloud/saturation pixels.
_FRP, Cmask, Smask, Fmask = (
    np.stack(_frames, axis=0) for _frames in (_FRP, _cloud, _sat, _fire)
)
# Elapsed time since the first scan, converted at minute resolution and
# expressed in seconds
_elapsed = np.asarray(FRP_times1)-FRP_times1[0]
FRP_times = _elapsed.astype("timedelta64[m]").astype(int) * 60.

# Save timeseries for use later
# np.save(filepath+'/totFRP', np.asarray(FRP_timeseries))
# np.save(filepath+'/FRPtimes', np.asarray(FRP_times))

print('\nCropping domain...')
# Centre of the plot extent as [lon, lat]
_lon_mid = ((extent[1]-extent[0])/2.)+extent[0]
_lat_mid = ((extent[3]-extent[2])/2.)+extent[2]
centroid = [_lon_mid, _lat_mid]
# Index bounds of a +/-0.4 degree box around the centroid
xvals1FRP,xvals2FRP,yvals1FRP,yvals2FRP = plumeProp(y,x,centroid,[.4,.4])

# Same 2-D window applied to the coordinate grids and to every time slice
_roi = (slice(xvals1FRP, xvals2FRP), slice(yvals1FRP, yvals2FRP))
_roi_t = (slice(None),) + _roi
fireX = x[_roi]
fireY = y[_roi]
_FRP = _FRP[_roi_t]
Cmask = Cmask[_roi_t]
Smask = Smask[_roi_t]
Fmask = Fmask[_roi_t]

# Total FRP per scan inside the window, ignoring -9 entries
FRP_timeseries = [
    np.ma.sum(np.ma.masked_where((_frame == -9.), _frame))
    for _frame in _FRP
]
    
#%%
# Number-of-detections map: sum of the (cropped) fire flags over time
Fsum = np.sum(Fmask,axis=0)
# where is there at least 1 fire pixel
# (row/col indexes are relative to the cropped arrays)
cfire = np.argwhere(Fsum > 0)
fire_rows, fire_cols = cfire[:,0], cfire[:,1]
# subset only valid fire data -> shape (time, n_fire_pixels)
frp22 = _FRP[:,fire_rows,fire_cols]
# cloud mask for stats later
cld_msk = Cmask[:,fire_rows,fire_cols]
# saturation mask for stats later
sat_msk = Smask[:,fire_rows,fire_cols]
# get lat/lon coords of fire pixels. cfire indexes the cropped domain, so the
# coordinates must come from the cropped grids (fireX/fireY); indexing the
# uncropped x/y here would return coordinates of unrelated pixels.
xx, yy = fireX[fire_rows,fire_cols], fireY[fire_rows,fire_cols]

# make no FRP nan instead of -9
newFRP = np.ma.masked_less_equal(frp22, 0)
newFRP = newFRP.filled(np.nan) # bad data=nan

# Extrapolate nans/zeros with end points (don't use)
# FRP_arr = []
# for ii in range(len(newFRP[0,:])):
#     # interpolate/fill interior nans
#     df = pd.DataFrame(newFRP[:,ii])
#     interp_df = df.interpolate()
#     FRP_arr.append(interp_df)
# # make filled FRPs into array
# FRP_arr = np.stack(FRP_arr, axis=1).squeeze()
# # fill edge case nans (might cause negative frps near zeros when filtered)
# FRP_arr = np.nan_to_num(FRP_arr,nan=0)

# make NaNs = 0
FRP_arr = np.nan_to_num(newFRP,nan=0)
# function for moving avg, x=array, w=window size
def moving_average(x, w):
    """Return the w-sample running mean of 1-D x, full windows only."""
    box = np.ones(w)
    window_sums = np.convolve(x, box, mode='valid')
    return window_sums / w

# Per-pixel results: normalized smoothed series, burnout time, peak index
FRParr2, Tf, times = [], [], []
for _series in FRP_arr.T:
    # 3-sample moving average of the unfiltered series to remove spikes/dips
    # (window=3, 3x5min=15 min avg)
    _smooth = moving_average(_series, 3)
    # position of the smoothed maximum
    _peak = np.nanargmax(_smooth)
    # e^-1 threshold for burnout time
    _cutoff = .36*_smooth[_peak]
    # offset (counted from the peak) of the first sample at/below the
    # threshold; argmax gives 0 when the threshold is never reached
    _offset = np.argmax(_smooth[_peak:] <= _cutoff, axis=0)

    if _offset > 0:
        # burnout time = time between the peak and the threshold crossing
        Tf.append(FRP_times[_peak+_offset]-FRP_times[_peak])
    else:
        Tf.append(np.nan)

    # peak indices are needed to centre the series later
    times.append(_peak)
    # normalize by the peak value
    FRParr2.append(_smooth/_smooth[_peak])

# back to a (time, pixel) array
FRP_arr = np.stack(FRParr2, axis=1)
# make burnout time array
Tf = np.asarray(Tf)

newFRP = []
# Pad every pixel series with NaNs so that all peaks share one index:
# (latest peak - own peak) NaNs in front, own-peak NaNs behind.
# All padded series then have the same length.
latest_peak = np.max(times)
for t, peak_idx in enumerate(times):
    beg = np.full(latest_peak-peak_idx, np.nan)
    end = np.full(peak_idx, np.nan)
    newFRP.append(np.concatenate((beg, FRP_arr[:,t], end)))

# normalize arrays at 0 seconds since max frp
# mean array
mean = np.nanmean(newFRP, axis=0)
# index of mean maximum
maxidx = np.nanargmax(mean, axis=0)

# master time array: 300 s per sample, 0 s at the mean maximum.
# Built directly from the sample index so it always has len(mean) entries;
# the previous begTime/endTime concatenation produced one extra, shifted
# sample when maxidx was 0.
newFRPtimes = (np.arange(len(mean))-maxidx)*300.
newFRP = np.stack(newFRP, axis=0)

# interquartile ranges
q75, q25 = np.nanpercentile(newFRP, [75, 25], axis=0)

# mean burnout time over pixels that have one
Tf_mean = np.nanmean(Tf)
# print(np.nanmedian(Tf))

w = .85*Tf_mean

# number of flagged pixel/time samples inside the cropped domain
cloudCt = np.count_nonzero(Cmask == 99)
satCt = np.count_nonzero(Smask == 99)
fireCt = np.count_nonzero(Fmask == 99)

# Plotting
######################
fig = plt.figure(num=1, figsize=(10, 5))
spec = gridspec.GridSpec(nrows=40, ncols=40, figure=fig)
ax1 = fig.add_subplot(spec[:, :]) #rows,cols

# Keep only the pixel series that have a burnout time, then average them
valid_means = np.stack(
    [_row for _row, _tf in zip(newFRP, Tf) if not np.isnan(_tf)],
    axis=0,
)
mean = np.nanmean(valid_means, axis=0)

# Times at which the mean curve is <= .36; the first positive one is Tf
_below_times = newFRPtimes[mean <= .36]
print(_below_times)
Tf_mean2 = _below_times[_below_times > 0][0]
w2 = Tf_mean2 * .85

# Interquartile band, mean curve and the Tf / .36 guide lines
ax1.fill_between(newFRPtimes, q25, q75, color='lightcoral')
ax1.plot(newFRPtimes, mean, color='maroon', linewidth=4, zorder=12)
ax1.axvline(x=Tf_mean2, color='k', linestyle='--', zorder=12)
ax1.axhline(y=.36, color='k', linestyle='--', zorder=12)

ax1.set_xlim(-5000, 20000)
ax1.set_ylim(0, 1)
_tick_style = dict(width=2, length=5, direction='in', labelsize=11, zorder=99)
ax1.xaxis.set_tick_params(bottom=True, top=True, **_tick_style)
ax1.yaxis.set_tick_params(left=True, right=True, **_tick_style)

# Set border around plot
for _side in ('top', 'bottom', 'left', 'right'):
    ax1.spines[_side].set_zorder(12)
    ax1.spines[_side].set_linewidth(2)

_label_style = dict(fontsize=12, fontweight='heavy')
ax1.set_xlabel('Normalized Time [sec]', **_label_style)
ax1.set_ylabel(r'Normalized FRP', **_label_style)
    
    
# inset map
# inset map in the upper-right part of the figure grid
ax2 = fig.add_subplot(spec[2:20,24:39], projection=ccrs.PlateCarree())
ax2.set_extent(extent)

# Create a Terrain instance.
esri_terrain = ShadedReliefESRI()
# Add the shaded-relief tiles at zoom level 12.
ax2.add_image(esri_terrain, 12)

# 2-D burnout-time map on the cropped grid: NaN everywhere except at the
# fire pixels, which receive their per-pixel Tf (seconds)
Tf2d = np.full(_FRP[0,:,:].shape, np.nan)
Tf2d[cfire[:,0],cfire[:,1]] = Tf

# divide by 60/60 to get hrs
p = ax2.pcolormesh(fireX, fireY, Tf2d/60./60., cmap='YlOrRd', vmin=0, vmax=6,
                   transform=ccrs.PlateCarree(), alpha=.8, zorder=9)

ax2.add_feature(cfeature.LAND, facecolor='k', alpha=.2, zorder=8)

ax2.set_xticks(np.linspace(extent[0]-.03,extent[1]+.03,5)[1:-1],crs=ccrs.PlateCarree()) # set longitude indicators
ax2.set_yticks(np.linspace(extent[2]-.03,extent[3]+.03,5)[1:-1],crs=ccrs.PlateCarree()) # set latitude indicators
lon_formatter = LongitudeFormatter(number_format='0.2f',degree_symbol=u'\N{DEGREE SIGN}',dateline_direction_label=True) # format lons
lat_formatter = LatitudeFormatter(number_format='0.2f',degree_symbol=u'\N{DEGREE SIGN}') # format lats
ax2.xaxis.set_major_formatter(lon_formatter) # set lons
ax2.yaxis.set_major_formatter(lat_formatter) # set lats
ax2.xaxis.set_tick_params(width=2,length=5,direction='in',labelsize=11, rotation=20, bottom=True, top=True, zorder=99)
ax2.yaxis.set_tick_params(width=2,length=5,direction='in',labelsize=11, left=True, right=True, zorder=99)

# Map frame: newer Cartopy exposes it as spines['geo']; the outline_patch
# attribute was deprecated and later removed. Fall back to the old attribute
# only when the 'geo' spine is not available.
try:
    map_frame = ax2.spines['geo']
except KeyError:
    map_frame = ax2.outline_patch
map_frame.set_linewidth(2)
map_frame.set_zorder(99)

# Colorbar
cax = fig.add_subplot(spec[5:18,35])
cbar = fig.colorbar(p, cax=cax, orientation='vertical')
ticks = [0,2,4,6]
cbar.set_ticks(ticks)
cbar.set_label(r'Burnout Time $\mathregular{[T_{f}, hrs]}$', fontsize=8, fontweight='heavy')
cbar.ax.xaxis.set_label_position('bottom') #if horizontal

# # stat box
# ax3 = fig.add_subplot(spec[22:31,26:39])
# ax3.set_xticks([]) 
# ax3.set_yticks([]) 

# Console summary of the burnout time and the derived weighting value
print('\n', keyword, sep='')
print('Burnout and weighting:')
print('---------------')
print('Tf = ', Tf_mean2, ' seconds', sep='')
print('w = ', w2, sep='')

# Flag counts, with cloud/saturation expressed as a percentage of fire samples
_cloud_pct = int(cloudCt/fireCt*100)
_sat_pct = int(satCt/fireCt*100)
print('\nPixel Totals:')
print('---------------')
print('Cloud Mask: ', cloudCt, ' / ', _cloud_pct, '%', sep='')
print('Saturation Mask: ', satCt, ' / ', _sat_pct, '%', sep='')
print('Fire pixels: ', fireCt, sep='')

plt.show()
sys.exit()

# NOTE(review): everything below is unreachable because of sys.exit() above;
# the figure is only written to disk if that call is removed.
_fig_dir = '/Users/matthewroberts/Documents/Projects/LEAPHI/Writing/Paper_figs/'
figname = '_'.join(('weighting_func_dist_IQ', keyword))
plt.savefig(_fig_dir+figname+'.png', bbox_inches='tight', dpi=300)
plt.close('all')


